#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/1/31 16:58
# @Author  : Aries
# @Site    : 
# @File    : sensor_OLI.py
# @Software: PyCharm Community Edition

from Py6S import *
import datetime
import numpy as np
import osr,ogr
from utils import *

def _extractHeaderParameters(inputHeader, epsg=None):
    """
    Parse a Landsat 8 OLI/TIRS MTL header file and locate the scene centre.

    :param inputHeader: path to the Landsat MTL metadata text file.
    :param epsg: EPSG code of the projected coordinate system that the
        CORNER_*_PROJECTION_*_PRODUCT values are expressed in. If None (the
        default) it is derived from the header's UTM_ZONE entry as
        WGS84 / UTM north (32600 + zone). If the header has no UTM_ZONE,
        EPSG:32643 (WGS84 / UTM zone 43N, the value this function previously
        hard-coded) is assumed.
    :return: tuple (lonCentre, latCentre, acquisitionTime):
        - lonCentre, latCentre: the centre of the projected image footprint,
          transformed to EPSG:4326 longitude/latitude.
        - acquisitionTime: a naive datetime.datetime built from DATE_ACQUIRED
          and SCENE_CENTER_TIME, with fractional seconds truncated.
    :raises KeyError: if a required header entry is missing.
    :raises ValueError: if the header is not for Landsat 8 OLI_TIRS, or if the
        projected corners do not form an axis-aligned rectangle.
    """
    print("Reading header file")
    headerParams = dict()
    # "with" guarantees the file is closed even if reading fails part-way.
    with open(inputHeader, 'r') as hFile:
        for line in hFile:
            lineVals = line.strip().split('=')
            # Keep only simple "KEY = VALUE" lines. Blank lines and values
            # containing '=' are skipped. GROUP / END_GROUP lines only delimit
            # sections, so they are not stored.
            if len(lineVals) != 2:
                continue
            key = lineVals[0].strip()
            if key not in ("GROUP", "END_GROUP"):
                headerParams[key] = lineVals[1].strip().replace('"', '')

    print("Extracting Header Values")
    # Only Landsat 8 OLI_TIRS headers are supported.
    spacecraftID = headerParams["SPACECRAFT_ID"]
    sensorID = headerParams["SENSOR_ID"]
    if spacecraftID.upper() not in ("LANDSAT_8", "LANDSAT8") or sensorID.upper() != "OLI_TIRS":
        raise ValueError("Do not recognise the spacecraft and sensor combination: %s / %s"
                         % (spacecraftID, sensorID))

    # Acquisition date/time: DATE_ACQUIRED is split on '-' into Y, M, D.
    # SCENE_CENTER_TIME is split on ':' into H, M, S; the fractional part of
    # the seconds (and any trailing 'Z') is dropped.
    acDate = headerParams["DATE_ACQUIRED"].split('-')
    acTime = headerParams["SCENE_CENTER_TIME"].split(':')
    wholeSecs = acTime[2].rstrip('Z').split('.')[0]
    acquisitionTime = datetime.datetime(int(acDate[0]), int(acDate[1]), int(acDate[2]),
                                        int(acTime[0]), int(acTime[1]), int(wholeSecs))

    # Projected X/Y corners of the image (assumes str2Float from utils returns a float).
    xTL = str2Float(headerParams["CORNER_UL_PROJECTION_X_PRODUCT"])
    yTL = str2Float(headerParams["CORNER_UL_PROJECTION_Y_PRODUCT"])
    xTR = str2Float(headerParams["CORNER_UR_PROJECTION_X_PRODUCT"])
    yTR = str2Float(headerParams["CORNER_UR_PROJECTION_Y_PRODUCT"])
    xBL = str2Float(headerParams["CORNER_LL_PROJECTION_X_PRODUCT"])
    yBL = str2Float(headerParams["CORNER_LL_PROJECTION_Y_PRODUCT"])
    xBR = str2Float(headerParams["CORNER_LR_PROJECTION_X_PRODUCT"])
    yBR = str2Float(headerParams["CORNER_LR_PROJECTION_Y_PRODUCT"])

    # The centre computation below is only valid for an axis-aligned footprint.
    if not (xTL == xBL and yTL == yTR and xTR == xBR and yBL == yBR):
        raise ValueError("Image is not square in projected coordinates.")

    xCentre = xTL + ((xTR - xTL) / 2)
    yCentre = yBR + ((yTL - yBR) / 2)

    # Use the header's own UTM zone instead of assuming zone 43N for every scene.
    if epsg is None:
        utmZone = headerParams.get("UTM_ZONE")
        epsg = (32600 + int(utmZone)) if utmZone else 32643

    inProj = osr.SpatialReference()
    inProj.ImportFromEPSG(epsg)
    wgs84latlonProj = osr.SpatialReference()
    wgs84latlonProj.ImportFromEPSG(4326)
    # GDAL >= 3 honours the authority axis order (lat, lon) for EPSG:4326,
    # which would swap GetX/GetY below. Force x=lon, y=lat when supported.
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        inProj.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        wgs84latlonProj.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    point = ogr.CreateGeometryFromWkt('POINT(%s %s)' % (xCentre, yCentre))
    point.AssignSpatialReference(inProj)
    point.TransformTo(wgs84latlonProj)

    lonCentre = point.GetX()
    latCentre = point.GetY()
    return lonCentre, latCentre, acquisitionTime

# Relative spectral response of Landsat 8 OLI bands 2-7, given as
# (band name, start wavelength um, end wavelength um, response values).
# The values are sampled every 2.5 nm from start to end, the spacing Py6S
# expects for Wavelength filter functions.
# The tuple order is the row order of the array returned by calc6SAtmosCorr.
# Band 1 (coastal aerosol) is deliberately not included.
_OLI_BAND_RESPONSES = (
    ("B2", 0.436, 0.5285,
     (0.000010, 0.000117, 0.000455, 0.001197, 0.006869, 0.027170, 0.271370,
      0.723971, 0.903034, 0.909880, 0.889667, 0.877453, 0.879688, 0.891913,
      0.848533, 0.828339, 0.868497, 0.912538, 0.931726, 0.954248, 0.956424,
      0.978564, 0.989469, 0.968801, 0.988729, 0.967361, 0.966125, 0.981834,
      0.963135, 0.996498, 0.844893, 0.190738, 0.005328, 0.001557, 0.000516,
      0.000162, 0.000023, -0.000016)),
    ("B3", 0.512, 0.6095,
     (-0.000046, 0.00011, 0.000648, 0.001332, 0.003446, 0.007024, 0.025513, 0.070551,
      0.353885, 0.741205, 0.954627, 0.959215, 0.969873, 0.961397, 0.977001, 0.990784,
      0.982642, 0.977765, 0.946245, 0.959038, 0.966447, 0.958314, 0.983397, 0.974522,
      0.978208, 0.974392, 0.969181, 0.982956, 0.968886, 0.986657, 0.904478, 0.684974,
      0.190467, 0.035393, 0.002574, 0.000394, -0.000194, -0.000292, -0.000348,
      -0.000351)),
    ("B4", 0.625, 0.690,
     (-0.000342, 0.000895, 0.007197, 0.030432, 0.299778, 0.764443, 0.950823, 0.951831,
      0.984173, 0.983434, 0.959441, 0.955548, 0.981688, 0.992388, 0.97696, 0.98108,
      0.980678, 0.962154, 0.966928, 0.848855, 0.123946, 0.017702, 0.001402, 0.000117,
      -0.000376, -0.000458, -0.000429)),
    ("B5", 0.829, 0.899,
     (-0.000034, 0.000050, 0.000314, 0.000719, 0.002107, 0.004744, 0.017346, 0.048191,
      0.249733, 0.582623, 0.960215, 0.973133, 1.000000, 0.980733, 0.957357, 0.947044,
      0.948450, 0.950632, 0.969821, 0.891066, 0.448364, 0.174619, 0.034532, 0.012440,
      0.002944, 0.001192, 0.000241, 0.000044, -0.000084)),
    ("B6", 1.515, 1.6975,
     (-0.00002, 0.00015, 0.00047, 0.00076, 0.00137, 0.00186, 0.00288, 0.00377,
      0.00553, 0.00732, 0.01099, 0.01430, 0.02183, 0.02995, 0.04786, 0.06573, 0.10189,
      0.13864, 0.22026, 0.29136, 0.42147, 0.52568, 0.67668, 0.75477, 0.85407, 0.89183,
      0.91301, 0.92295, 0.92641, 0.92368, 0.92283, 0.92206, 0.92661, 0.94253, 0.94618,
      0.94701, 0.95286, 0.94967, 0.95905, 0.96005, 0.96147, 0.96018, 0.96470, 0.96931,
      0.97691, 0.98126, 0.98861, 0.99802, 0.99964, 0.99344, 0.96713, 0.93620, 0.84097,
      0.75189, 0.57323, 0.45197, 0.29175, 0.21115, 0.12846, 0.09074, 0.05275, 0.03731,
      0.02250, 0.01605, 0.00959, 0.00688, 0.00426, 0.00306, 0.00178, 0.00124, 0.00068,
      0.00041, 0.00011, -0.00003)),
    ("B7", 2.037, 2.3545,
     (-0.000010, 0.000083, 0.000240, 0.000368, 0.000599, 0.000814, 0.001222, 0.001546,
      0.002187, 0.002696, 0.003733, 0.004627, 0.006337, 0.007996, 0.011005, 0.013610,
      0.018899, 0.023121, 0.032071, 0.040206, 0.056429, 0.070409, 0.100640, 0.128292,
      0.179714, 0.227234, 0.311347, 0.377044, 0.488816, 0.554715, 0.663067, 0.722284,
      0.792667, 0.836001, 0.867845, 0.886411, 0.906527, 0.911091, 0.929693, 0.936544,
      0.942952, 0.943194, 0.948776, 0.949643, 0.956635, 0.947423, 0.950874, 0.947014,
      0.957717, 0.946412, 0.951641, 0.948644, 0.940311, 0.947923, 0.938737, 0.941859,
      0.944482, 0.951661, 0.939939, 0.935493, 0.938955, 0.929162, 0.930508, 0.933908,
      0.936472, 0.933523, 0.946217, 0.955661, 0.963135, 0.964365, 0.962905, 0.962473,
      0.957814, 0.958041, 0.951706, 0.960212, 0.947696, 0.959060, 0.955750, 0.953245,
      0.966786, 0.960173, 0.977637, 0.982760, 0.985056, 0.999600, 0.992469, 0.995894,
      0.997261, 0.991127, 0.986037, 0.984536, 0.972794, 0.976540, 0.974409, 0.967502,
      0.955095, 0.955588, 0.922405, 0.894940, 0.823876, 0.744025, 0.602539, 0.502693,
      0.355569, 0.278260, 0.186151, 0.141435, 0.092029, 0.069276, 0.046332, 0.035634,
      0.024000, 0.018688, 0.012930, 0.010155, 0.007088, 0.005643, 0.003903, 0.003025,
      0.002047, 0.001554, 0.000974, 0.000680, 0.000320, 0.000119, -0.000134,
      -0.000263)),
)

# 6S output values stored per band, in column order of the returned array.
_SIXS_OUTPUT_KEYS = ('coef_xa', 'coef_xb', 'coef_xc', 'measured_radiance',
                     'atmos_corrected_reflectance_lambertian')


def calc6SAtmosCorr(aeroProfile, atmosProfile, groundRefl, surfaceAltitude, toa, useBRDF, acqtime, lon, lat,
                    aot550=0.01):
    """
    Run 6S for Landsat 8 OLI bands 2-7 and collect atmospheric-correction outputs.

    :param aeroProfile: Py6S aerosol profile, assigned to SixS.aero_profile.
    :param atmosProfile: Py6S atmospheric profile, assigned to SixS.atmos_profile.
    :param groundRefl: Py6S ground reflectance model, assigned to SixS.ground_reflectance.
    :param surfaceAltitude: target altitude passed to Altitudes.set_target_custom_altitude.
    :param toa: top-of-atmosphere reflectance handed to the 6S atmospheric-correction mode.
    :param useBRDF: if true use AtmosCorrBRDFFromReflectance, otherwise
        AtmosCorrLambertianFromReflectance.
    :param acqtime: acquisition time (datetime.datetime). Month, day and time of
        day set the 6S geometry; the time of day is interpreted as GMT/UTC.
    :param lon: scene longitude used in the 6S geometry.
    :param lat: scene latitude used in the 6S geometry.
    :param aot550: aerosol optical thickness at 550 nm. Defaults to 0.01, the
        value this function previously hard-coded.
    :return: float32 array of shape (6, 5).
        Row i corresponds to OLI band i + 2.
        Columns are coef_xa, coef_xb, coef_xc, measured_radiance and
        atmos_corrected_reflectance_lambertian, in that order.
    """
    # Set up the 6S model shared by all bands; only the wavelength changes per run.
    s = SixS()
    s.atmos_profile = atmosProfile
    s.aero_profile = aeroProfile
    s.ground_reflectance = groundRefl
    s.geometry = Geometry.Landsat_TM()
    s.geometry.month = acqtime.month
    s.geometry.day = acqtime.day
    # 6S takes the overpass time as decimal hours GMT; include seconds as well.
    s.geometry.gmt_decimal_hour = (float(acqtime.hour) + float(acqtime.minute) / 60.0
                                   + float(acqtime.second) / 3600.0)
    s.geometry.latitude = lat
    s.geometry.longitude = lon

    s.altitudes = Altitudes()
    s.altitudes.set_target_custom_altitude(surfaceAltitude)
    s.altitudes.set_sensor_satellite_level()

    # Aerosol loading is specified via the aerosol optical thickness at 550 nm.
    # Pass a measured value through `aot550` when one is available.
    s.aot550 = aot550
    if useBRDF:
        s.atmos_corr = AtmosCorr.AtmosCorrBRDFFromReflectance(toa)
    else:
        s.atmos_corr = AtmosCorr.AtmosCorrLambertianFromReflectance(toa)

    sixsCoeffs = np.zeros((len(_OLI_BAND_RESPONSES), len(_SIXS_OUTPUT_KEYS)), dtype=np.float32)
    for row, (_bandName, startWl, endWl, response) in enumerate(_OLI_BAND_RESPONSES):
        # Pass a fresh list so the module-level table can never be modified.
        s.wavelength = Wavelength(startWl, endWl, list(response))
        s.run()
        outputs = s.outputs.values
        for col, key in enumerate(_SIXS_OUTPUT_KEYS):
            sixsCoeffs[row, col] = float(outputs[key])
    return sixsCoeffs